{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-13T01:44:35.907330Z",
     "start_time": "2020-06-13T01:44:34.844828Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No GPU found\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "\n",
    "# 使用 CPU\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '-1'\n",
    "\n",
    "if tf.test.gpu_device_name():\n",
    "    print('GPU found')\n",
    "else:\n",
    "    print(\"No GPU found\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 遮挡"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 填充"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-13T02:05:07.118926Z",
     "start_time": "2020-06-13T02:05:07.115203Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  83,   91,    1,  645, 1253,  927],\n",
       "       [  73,    8, 3215,   55,  927,    0],\n",
       "       [ 711,  632,   71,    0,    0,    0]], dtype=int32)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_inputs = [\n",
    "    [83, 91, 1, 645, 1253, 927],\n",
    "    [73, 8, 3215, 55, 927],\n",
    "    [711, 632, 71],\n",
    "]\n",
    "\n",
    "padded_inputs = tf.keras.preprocessing.sequence.pad_sequences(raw_inputs,\n",
    "                                                              padding='post')\n",
    "padded_inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-13T03:35:11.390569Z",
     "start_time": "2020-06-13T03:35:11.388800Z"
    }
   },
   "source": [
    "## 生成遮挡"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 嵌入层生成遮挡\n",
    "> 设置属性`mask_zero=True`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-13T01:47:59.397925Z",
     "start_time": "2020-06-13T01:47:59.390492Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[ True  True  True  True  True  True]\n",
      " [ True  True  True  True  True False]\n",
      " [ True  True  True False False False]], shape=(3, 6), dtype=bool)\n"
     ]
    }
   ],
   "source": [
    "embedding = layers.Embedding(\n",
    "    input_dim=5000,\n",
    "    output_dim=16,\n",
    "    mask_zero=True,  # 遮挡索引 0 \n",
    ")\n",
    "\n",
    "masked_output = embedding(padded_inputs)\n",
    "\n",
    "print(masked_output._keras_mask)\n",
    "# 幕后会创建一个遮挡矩阵"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 遮挡层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-13T02:05:12.460027Z",
     "start_time": "2020-06-13T02:05:12.443773Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[ True  True  True  True  True  True]\n",
      " [ True  True  True  True  True False]\n",
      " [ True  True  True False False False]], shape=(3, 6), dtype=bool)\n"
     ]
    }
   ],
   "source": [
    "masking_layer = layers.Masking()\n",
    "\n",
    "unmasked_embedding = tf.cast(\n",
    "    tf.tile(tf.expand_dims(padded_inputs, axis=-1), [1, 1, 10]), tf.float32)\n",
    "# 模拟批量的词嵌入查找结果，batch，seq_len,embed_size\n",
    "\n",
    "masked_embedding = masking_layer(unmasked_embedding)\n",
    "print(masked_embedding._keras_mask)\n",
    "# mask 结果：batch,seq_len，表征序列中哪个词被遮挡"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 遮挡的传播"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sequential 模型接口\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-13T02:11:29.684546Z",
     "start_time": "2020-06-13T02:11:29.203884Z"
    }
   },
   "outputs": [],
   "source": [
    "model = tf.keras.Sequential([\n",
    "    layers.Embedding(\n",
    "        input_dim=5000,\n",
    "        output_dim=16,\n",
    "        mask_zero=True, # 设置遮挡\n",
    "    ),\n",
    "    layers.LSTM(32) # LSTM 层会自动接受遮挡，忽略填充的值\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Functional 模型\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = tf.keras.Input(shape=(None, ), dtype='int32')\n",
    "x = layers.Embedding(\n",
    "    input_dim=5000,\n",
    "    output_dim=16,\n",
    "    mask_zero=True,  # 设置遮挡\n",
    ")(inputs)\n",
    "outputs = layers.LSTM(32)(x)\n",
    "\n",
    "model = tf.keras.Model(inputs, outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 自定义层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-13T02:17:39.370811Z",
     "start_time": "2020-06-13T02:17:39.289480Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(32, 32), dtype=float32, numpy=\n",
       "array([[-4.2033270e-03, -3.2198273e-03, -3.2618330e-03, ...,\n",
       "        -5.1559320e-05, -7.5503844e-03, -2.6671903e-03],\n",
       "       [-1.5757827e-03, -8.7853119e-04, -5.4395376e-03, ...,\n",
       "         1.3437866e-03,  4.2481534e-03,  1.0843794e-03],\n",
       "       [-2.6922526e-03,  6.8654576e-03,  3.0769468e-03, ...,\n",
       "        -6.4012818e-03, -8.7724405e-04,  5.9047979e-03],\n",
       "       ...,\n",
       "       [ 3.6077171e-03,  2.4239554e-03, -1.6135219e-03, ...,\n",
       "        -4.7010854e-03, -7.1414036e-04, -1.5066147e-03],\n",
       "       [ 1.7981452e-04,  8.5821534e-03, -1.8688419e-04, ...,\n",
       "        -3.0656459e-03, -1.5515659e-03, -1.7525189e-03],\n",
       "       [ 2.2413780e-03,  3.6969897e-03,  2.5818348e-03, ...,\n",
       "         4.7936742e-03, -1.7710224e-03, -1.6857670e-03]], dtype=float32)>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "class MyLayer(layers.Layer):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(MyLayer, self).__init__(**kwargs)\n",
    "        self.embedding = layers.Embedding(\n",
    "            input_dim=5000,\n",
    "            output_dim=16,\n",
    "            mask_zero=True,  # 设置遮挡\n",
    "        )\n",
    "        self.lstm = layers.LSTM(32)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        x = self.embedding(inputs)\n",
    "\n",
    "        # 创建遮挡\n",
    "        # 也可以手动创建，只需要形状 batch,timesteps 的 bool 类型张量\n",
    "        mask = self.embedding.compute_mask(inputs)\n",
    "\n",
    "        output = self.lstm(\n",
    "            x,\n",
    "            mask=mask,  # 传递遮挡参数，自动忽略被遮挡的单词\n",
    "        )\n",
    "        return output\n",
    "\n",
    "\n",
    "layer = MyLayer()\n",
    "x = np.random.random((32, 10)) * 100\n",
    "x = x.astype('int32')\n",
    "layer(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 支持 mask 的自定义层\n",
    "> 在层中定义 layer.compute_mask 方法，以默认的遮挡和给定输入，生成新的遮挡"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-13T03:14:40.749287Z",
     "start_time": "2020-06-13T03:14:40.745132Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3, 6, 10), dtype=float32, numpy=\n",
       "array([[[8.300e+01, 8.300e+01, 8.300e+01, 8.300e+01, 8.300e+01,\n",
       "         8.300e+01, 8.300e+01, 8.300e+01, 8.300e+01, 8.300e+01],\n",
       "        [9.100e+01, 9.100e+01, 9.100e+01, 9.100e+01, 9.100e+01,\n",
       "         9.100e+01, 9.100e+01, 9.100e+01, 9.100e+01, 9.100e+01],\n",
       "        [1.000e+00, 1.000e+00, 1.000e+00, 1.000e+00, 1.000e+00,\n",
       "         1.000e+00, 1.000e+00, 1.000e+00, 1.000e+00, 1.000e+00],\n",
       "        [6.450e+02, 6.450e+02, 6.450e+02, 6.450e+02, 6.450e+02,\n",
       "         6.450e+02, 6.450e+02, 6.450e+02, 6.450e+02, 6.450e+02],\n",
       "        [1.253e+03, 1.253e+03, 1.253e+03, 1.253e+03, 1.253e+03,\n",
       "         1.253e+03, 1.253e+03, 1.253e+03, 1.253e+03, 1.253e+03],\n",
       "        [9.270e+02, 9.270e+02, 9.270e+02, 9.270e+02, 9.270e+02,\n",
       "         9.270e+02, 9.270e+02, 9.270e+02, 9.270e+02, 9.270e+02]],\n",
       "\n",
       "       [[7.300e+01, 7.300e+01, 7.300e+01, 7.300e+01, 7.300e+01,\n",
       "         7.300e+01, 7.300e+01, 7.300e+01, 7.300e+01, 7.300e+01],\n",
       "        [8.000e+00, 8.000e+00, 8.000e+00, 8.000e+00, 8.000e+00,\n",
       "         8.000e+00, 8.000e+00, 8.000e+00, 8.000e+00, 8.000e+00],\n",
       "        [3.215e+03, 3.215e+03, 3.215e+03, 3.215e+03, 3.215e+03,\n",
       "         3.215e+03, 3.215e+03, 3.215e+03, 3.215e+03, 3.215e+03],\n",
       "        [5.500e+01, 5.500e+01, 5.500e+01, 5.500e+01, 5.500e+01,\n",
       "         5.500e+01, 5.500e+01, 5.500e+01, 5.500e+01, 5.500e+01],\n",
       "        [9.270e+02, 9.270e+02, 9.270e+02, 9.270e+02, 9.270e+02,\n",
       "         9.270e+02, 9.270e+02, 9.270e+02, 9.270e+02, 9.270e+02],\n",
       "        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00,\n",
       "         0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00]],\n",
       "\n",
       "       [[7.110e+02, 7.110e+02, 7.110e+02, 7.110e+02, 7.110e+02,\n",
       "         7.110e+02, 7.110e+02, 7.110e+02, 7.110e+02, 7.110e+02],\n",
       "        [6.320e+02, 6.320e+02, 6.320e+02, 6.320e+02, 6.320e+02,\n",
       "         6.320e+02, 6.320e+02, 6.320e+02, 6.320e+02, 6.320e+02],\n",
       "        [7.100e+01, 7.100e+01, 7.100e+01, 7.100e+01, 7.100e+01,\n",
       "         7.100e+01, 7.100e+01, 7.100e+01, 7.100e+01, 7.100e+01],\n",
       "        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00,\n",
       "         0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],\n",
       "        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00,\n",
       "         0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],\n",
       "        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00,\n",
       "         0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00]]],\n",
       "      dtype=float32)>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masked_embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-13T03:18:34.905786Z",
     "start_time": "2020-06-13T03:18:34.885395Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[ True  True  True  True  True  True]\n",
      " [ True  True  True  True  True False]\n",
      " [ True  True  True False False False]], shape=(3, 6), dtype=bool)\n"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "'tensorflow.python.framework.ops.EagerTensor' object has no attribute '_keras_mask'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-32-540f502317e9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDefault\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmasked_embedding\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_keras_mask\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# 默认的 mask，不满足要求\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_keras_mask\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# 无 mask\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m: 'tensorflow.python.framework.ops.EagerTensor' object has no attribute '_keras_mask'"
     ]
    }
   ],
   "source": [
    "class Default(tf.keras.layers.Layer):\n",
    "    def call(self, inputs):\n",
    "        # 该层会生成张量列表格，默认的 遮挡 masking(inputs) 将不能满足要求\n",
    "        return tf.split(inputs, 2, axis=1)\n",
    "    \n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        if mask is None:\n",
    "            return None\n",
    "        # 默认的 mask\n",
    "        return mask   \n",
    "    \n",
    "a, b = Default()(masked_embedding)\n",
    "print(a._keras_mask) # 默认的 mask，不满足要求\n",
    "print(b._keras_mask) # 无 mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-13T03:18:38.983221Z",
     "start_time": "2020-06-13T03:18:38.971142Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[ True  True  True]\n",
      " [ True  True  True]\n",
      " [ True  True  True]], shape=(3, 3), dtype=bool)\n",
      "tf.Tensor(\n",
      "[[ True  True  True]\n",
      " [ True  True False]\n",
      " [False False False]], shape=(3, 3), dtype=bool)\n"
     ]
    }
   ],
   "source": [
    "class TemporalSplit(tf.keras.layers.Layer):\n",
    "    def call(self, inputs):\n",
    "        # 该层会生成张量列别，默认的 遮挡 masking(inputs) 将不能满足要求\n",
    "        return tf.split(inputs, 2, axis=1)\n",
    "\n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        if mask is None:\n",
    "            return None\n",
    "        # 将 mask 也拆分成列表\n",
    "        return tf.split(mask, 2, axis=1)\n",
    "\n",
    "first_half, second_half = TemporalSplit()(masked_embedding)\n",
    "print(first_half._keras_mask)\n",
    "print(second_half._keras_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-13T03:25:31.722021Z",
     "start_time": "2020-06-13T03:25:31.712343Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[False  True  True False False  True False  True  True  True]\n",
      " [ True  True  True  True  True  True False  True  True  True]\n",
      " [ True  True  True  True  True  True  True  True  True  True]], shape=(3, 10), dtype=bool)\n"
     ]
    }
   ],
   "source": [
    "class CustomEmbedding(tf.keras.layers.Layer):\n",
    "    def __init__(self, input_dim, output_dim, mask_zero=False, **kwargs):\n",
    "        super(CustomEmbedding, self).__init__(**kwargs)\n",
    "        self.input_dim = input_dim\n",
    "        self.output_dim = output_dim\n",
    "        self.mask_zero = mask_zero\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        self.embeddings = self.add_weight(\n",
    "            shape=(self.input_dim, self.output_dim),\n",
    "            initializer='random_normal',\n",
    "            dtype='float32',\n",
    "        )\n",
    "\n",
    "    def call(self, inputs):\n",
    "        return tf.nn.embedding_lookup(self.embeddings, inputs)\n",
    "\n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        if not self.mask_zero:\n",
    "            return None\n",
    "        return tf.not_equal(inputs, 0)\n",
    "\n",
    "\n",
    "layer = CustomEmbedding(10, 32, mask_zero=True)\n",
    "x = np.random.random((3, 10)) * 9\n",
    "x = x.astype('int32')\n",
    "\n",
    "y = layer(x)\n",
    "mask = layer.compute_mask(x)\n",
    "\n",
    "print(mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 带 mask 参数的层\n",
 添加 `mask=None`">
    "> 在层的 `call` 方法签名中添加 `mask=None` 参数，Keras 会把上游传播来的遮挡传入该参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MaskConsumer(tf.keras.layers.Layer):\n",
    "    def call(self, inputs, mask=None):\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 总结：\n",
    "- masking 让层知道应该跳过或忽略输入序列的特定的时间步\n",
    "- 一些层，自动生成 masking，如：指定mask_zero=True 的 Embedding 层，Masking 层\n",
    "- 一些层是 mask 的消费者，在其 `__call__` 方法中公开 mask 参数，如 RNN 层\n",
    "- 函数式及序列式模型接口中，遮挡信息会自动传播"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 注意力中的遮挡\n",
    "## 填充产生的遮挡\n",
 输入序列可能是">
    "> 输入序列可能经过填充以对齐为相同长度，计算注意力时需要遮挡这些填充位置。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-13T04:23:42.025727Z",
     "start_time": "2020-06-13T04:23:42.020169Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3, 1, 1, 3), dtype=float32, numpy=\n",
       "array([[[[0., 0., 1.]]],\n",
       "\n",
       "\n",
       "       [[[0., 0., 0.]]],\n",
       "\n",
       "\n",
       "       [[[0., 1., 1.]]]], dtype=float32)>"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def create_padding_mask(seq):\n",
    "    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)\n",
    "    # 添加额外的维度来将填充加到\n",
    "    # 注意力对数（logits）。\n",
    "    return seq[:, tf.newaxis, tf.newaxis, :]  # (batch, 1, 1, seq_len)\n",
    "\n",
    "\n",
    "x = tf.constant([[7, 6, 0], [1, 2, 3], [1, 0, 0]])\n",
    "create_padding_mask(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-13T04:19:43.397608Z",
     "start_time": "2020-06-13T04:19:43.383454Z"
    }
   },
   "outputs": [],
   "source": [
    "def scaled_dot_product_attention(q, k, v, mask):\n",
    "    \"\"\"\n",
    "    q: batch, q_len, q_dim\n",
    "    k: batch, k_len, q_dim\n",
    "    v: batch, k_len, v_dim\n",
    "    \n",
    "    mask: batch, 1,1, k_len\n",
    "    \"\"\"\n",
    "    attention_score = tf.matmul(q, k, transpose_b=True)  # batch,q_len,k_len\n",
    "    dk = tf.cast(tf.shape(k)[-1], tf.float32)\n",
    "    scaled_attetntion_score = attention_score / tf.math.sqrt(dk)\n",
    "\n",
    "    # 被 mask 的位置对应的权重乘以 -1e9，sofmax时，该处的概率就变为 0 了\n",
    "    if mask is not None:\n",
    "        scaled_score += (mask * -1e9)\n",
    "\n",
    "    attention_distribution = tf.nn.softmax(scaled_attetntion_score, -1)\n",
    "    output = tf.matmul(attention_distribution, v)  # batch,q_len,v_dim\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 前瞻遮挡\n",
    "- 前瞻遮挡（look-ahead mask）用于遮挡一个序列中的后续标记（future tokens）。\n",
    "> 这意味着要预测第三个词，将仅使用第一个和第二个词。与此类似，预测第四个词，仅使用第一个，第二个和第三个词，依此类推。 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-13T04:22:40.695115Z",
     "start_time": "2020-06-13T04:22:40.691053Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3, 3), dtype=float32, numpy=\n",
       "array([[0., 1., 1.],\n",
       "       [0., 0., 1.],\n",
       "       [0., 0., 0.]], dtype=float32)>"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def create_look_ahead_mask(size):\n",
    "    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)\n",
    "    return mask  # (seq_len, seq_len)\n",
    "\n",
    "\n",
    "x = tf.random.uniform((1, 3))\n",
    "temp = create_look_ahead_mask(x.shape[1])\n",
    "temp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
